Add U8 copy operation for K16 MMA #374

Merged

Conversation

aacostadiaz (Collaborator)

This PR adds the U8 copy operation that works correctly with the K16 MMA for FP8 GEMM or mixed dtype GEMM.
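For context, the kind of kernel configuration this enables looks roughly like the sketch below. It is assembled from type names that appear later in this thread (the XE_2D_U8x32x32_LD_N / XE_2D_U8x32x32_LD_V copy atoms and the K16 MMA atom XE_8x16x16_F32F16F16F32_TT), not verbatim code from the PR:

  // Sketch only -- U8 2D block copies paired with a K16 MMA atom; names taken
  // from the discussion below, exact atoms depend on the tile shape and dtypes.
  using GmemTiledCopyA = XE_2D_U8x32x32_LD_N;   // 32x32 tile of 8-bit elements for A
  using GmemTiledCopyB = XE_2D_U8x32x32_LD_V;   // VNNI-style load for B

  using TileShape = Shape<_64, _256, _32>;      // (M, N, K) workgroup tile

  using TiledMma = typename TiledMMAHelper<
      MMA_Atom<XE_8x16x16_F32F16F16F32_TT>,     // K16 MMA atom (f32 += f16 * f16)
      Layout<TileShape>,
      Layout<Shape<_2, _8, _1>, Stride<_8, _1, _0>>>::TiledMMA;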

jiyang1011 and others added 22 commits April 7, 2025 19:12

sanchitintel (Collaborator) commented May 21, 2025

With FP8xFP8 GEMM, this config didn't work, although the corresponding config works for FP16xFP16 GEMM:

  using GmemTiledCopyA = XE_2D_U8x32x32_LD_N;
  using GmemTiledCopyB = XE_2D_U8x32x32_LD_V;

  using TileShape = Shape<_64, _256, _32>;

  using TiledMma =
      typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32F16F16F32_TT>, Layout<TileShape>,
      Layout<Shape<_2, _8, _1>, Stride<_8, _1, _0>>>::TiledMMA;

The compile-time error was

include/cute/atom/copy_traits_xe.hpp:78:19: error: static assertion failed due to requirement 'size(cute::Layout<cute::tuple<cute::C<16>, cute::C<8>>, cute::tuple<cute::C<0>, cute::C<1>>>{}) % size(cute::tuple<cute::C<8>, cute::C<64>>{}) == 0'
   78 |     static_assert(size(LayoutIn{}) % size(BlockShape{}) == 0);

It seems to be a bug since the shapes are correct.

Thanks!

aacostadiaz removed the incremental (Incremental changes) label on May 27, 2025
Comment on lines +217 to +228
struct XE_2D_U8x32x32_LD_N {
  using BlockShape = Shape<_32, _32>;

  template <class T>
  CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
                                    int height, int pitch, intel::coord_t coord,
                                    T *dst) {
#if defined(CUTE_ARCH_COPY_XE_ENABLED)
    static_assert(sizeof(T) == 1, "Expected T to have size 1");
    // detail::XeSubgroup2DBlockLoad<1, 16, 32, 2>{}(baseoffset, width, height, pitch, coord, dst);
    // Use the transform (VNNI) version as it provides better performance when loading the A matrix for
    // GEMM FP8 and GEMM mixed-precision types.
sanchitintel (Collaborator) May 27, 2025

Hi @aacostadiaz,

Please help resolve a couple of doubts.

The DstLayout in the atom traits for this copy atom is Layout<Shape <_16,Shape <_8, _2, _32>>, Stride<_16,Stride< _1,_128,_256>>>, which seems to correspond to a plain (non-VNNI) layout. So, does this mean that when the data is copied from global memory, it is first transformed into the VNNI layout before being written to the registers, and later converted to DstLayout? If yes, could you please point out where/how this is handled in the code?

Also, I don't see any shfl-based instructions in the generated assembly dump. Is it possible that the shuffle (for the VNNI -> plain layout conversion) is not happening directly via lane registers -> lane registers (I understand this isn't possible on Nvidia GPUs, but the documentation suggests it is somehow possible on Intel GPUs), but via lane registers -> shared local memory -> lane registers?

Thanks!

cc @pengzhao-intel @yuankuns

aacostadiaz (Collaborator, Author)

The Copy trait is used to describe how a copy operation works so that the rest of the code can understand it. It does not change how the actual copy operation works.

In this case, for the VNNI copies, the transformation happens inside the builtin/SPIR-V function. There is no transformation inside CUTLASS for that; we just use these builtin/SPIR-V functions, and the copy traits describe how they work.
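For reference, such a trait is purely descriptive metadata. The sketch below follows upstream CuTe's Copy_Traits convention (ThrID / SrcLayout / DstLayout / RefLayout) and reuses the bit layouts quoted elsewhere in this thread; the actual specialization in include/cute/atom/copy_traits_xe.hpp may differ in its details:

  // Sketch only -- not the code added by this PR. The layouts map (thread, value)
  // coordinates to bit offsets within the copied block; they describe what the
  // XeSubgroup2DBlockLoad* builtin produces, they do not implement the copy itself.
  template <>
  struct Copy_Traits<XE_2D_U8x32x32_LD_N> {
    using ThrID = Layout<_16>;                   // one sub-group of 16 work-items
    // Map from (src-thr, src-val) to bit
    using SrcLayout = Layout<Shape <_16, Shape <_8, _2, _32>>,
                             Stride< _0, Stride< _1, _128, _256>>>;
    // Map from (dst-thr, dst-val) to bit
    using DstLayout = Layout<Shape <_16, Shape <_8, _2, _32>>,
                             Stride<_16, Stride< _1, _128, _256>>>;
    using RefLayout = DstLayout;                 // reference frame used for partitioning
  };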

sanchitintel (Collaborator) May 30, 2025

@aacostadiaz, thanks, but I meant that since A for FP8 GEMM is being loaded in VNNI layout in this PR and the GEMM output is correct, the layout must have been changed from VNNI back to plain somewhere in the code.

In this case, for the VNNI copies the transformation happens inside the builtin/spirv function

Sorry, do you mean the VNNI -> plain transformation also happens inside the builtin? Thanks!

aacostadiaz (Collaborator, Author)

Yes, XeSubgroup2DBlockLoad<1, 16, 32, 2> and XeSubgroup2DBlockLoadTransform<1, 16, 32, 2> (Transform is the VNNI transformation) load exactly the same data, and we end up with exactly the same values in the registers. The only difference with XeSubgroup2DBlockLoadTransform<1, 16, 32, 2> is that the packing is 32 bits, so we get 32-bit elements out of the copy operation. If you recast these into four 8-bit elements each, you have exactly the same information as with the XeSubgroup2DBlockLoad<1, 16, 32, 2> copy.
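To illustrate the equivalence described above, here is a small host-side sketch (not code from the PR) that models one work-item's column of a 32-row U8 block. It assumes the transform packs four consecutive rows of that column into one little-endian 32-bit value:

  #include <cassert>
  #include <cstdint>
  #include <cstring>
  #include <vector>

  int main() {
    constexpr int H = 32, W = 16;           // block height x width in bytes; one column per work-item
    std::vector<uint8_t> block(H * W);
    for (int i = 0; i < H * W; ++i) block[i] = static_cast<uint8_t>(i);

    const int lane = 5;                     // pick one work-item / column

    // "Plain" load: the work-item receives the 32 bytes of its column, top to bottom.
    uint8_t plain[H];
    for (int r = 0; r < H; ++r) plain[r] = block[r * W + lane];

    // "Transform" (VNNI) load: the same column, packed 4 bytes per 32-bit value.
    uint32_t packed[H / 4];
    for (int d = 0; d < H / 4; ++d) {
      packed[d] = 0;
      for (int b = 0; b < 4; ++b)
        packed[d] |= static_cast<uint32_t>(block[(4 * d + b) * W + lane]) << (8 * b);
    }

    // Recasting the packed 32-bit values back into bytes recovers the plain column
    // exactly (on a little-endian host), i.e. the same information either way.
    uint8_t recast[H];
    std::memcpy(recast, packed, sizeof(recast));
    assert(std::memcmp(recast, plain, H) == 0);
    return 0;
  }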

sanchitintel (Collaborator) Jun 5, 2025

Ah, since we load the data column-wise (from the POV of one work-item) with XeSubgroup2DBlockLoad anyway, it doesn't matter whether we use XeSubgroup2DBlockLoadTransform or XeSubgroup2DBlockLoad. (I haven't yet reasoned about whether it would work for all relevant tile shapes, though; I'll do that later.)

From https://github.khronos.org/SPIRV-Registry/extensions/INTEL/SPV_INTEL_2d_block_io.html,

[screenshot of the 2D block load description from the SPV_INTEL_2d_block_io specification]

cfgfung commented May 28, 2025

Hi @aacostadiaz,

The vLLM team is blocked by this issue. Could you please prioritize this and merge it into the main branch?

joeatodd (Collaborator) left a comment

I'm unsure about the Layout for the new operation, which looks like it might relate to @sanchitintel's comment.

Aside from that, just a nit suggestion.

Comment on lines 640 to 644
using SrcLayout = Layout<Shape <_16,Shape <_8, _2, _32>>,
                         Stride< _0,Stride< _1,_128,_256>>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape <_16,Shape <_8, _2, _32>>,
                         Stride<_16,Stride< _1,_128,_256>>>;
Collaborator

It looks like XE_2D_Packed_U8x32x32_LD_N and XE_2D_U8x32x32_LD_N have the same *Layout traits. Is that expected?

Collaborator

I'll check out the copy_debug tool to verify why they look similar (they were the same when you commented) and will report back with any findings. Thanks!

aacostadiaz and others added 2 commits May 29, 2025 17:28
Co-authored-by: Joe Todd <[email protected]>
Co-authored-by: Tadej Ciglarič <[email protected]>
@@ -535,7 +535,7 @@ int main(int argc, const char** argv)
using ElementScale = MmaType;

// Note: XE_2D_U18x32x32_LD_N is incompatible with our bf16 MMA atoms
sanchitintel (Collaborator) May 30, 2025

nit: this comment seems to be obsolete now.

muhammad-tanvir-1211 (Collaborator) left a comment

LGTM, Thanks.

7 participants